from __future__ import annotations

import logging
import typing
from collections.abc import Collection
from typing import Dict, Any, List, cast

from checkov.common.graph.graph_builder.graph_components.attribute_names import CustomAttributes
from checkov.common.graph.graph_builder.utils import calculate_hash, join_trimmed_strings
from checkov.common.graph.graph_builder.variable_rendering.breadcrumb_metadata import BreadcrumbMetadata
from checkov.common.util.data_structures_utils import pickle_deepcopy

from bc_jsonpath_ng.ext import parse

if typing.TYPE_CHECKING:
    from bc_jsonpath_ng import JSONPath


class Block:
    """Base representation of one configuration block used by the graph builder.

    A block keeps a deep copy of its raw ``config``, a flat ``attributes`` dict in which nested
    dict/list values are additionally exposed under dotted keys (e.g. ``"tags.Name"``,
    ``"rules.0"``), and the bookkeeping written while attribute values are updated
    (``changed_attributes`` and ``breadcrumbs``).
    """

    __slots__ = (
        "attributes",
        "block_type",
        "breadcrumbs",
        "changed_attributes",
        "config",
        "id",
        "name",
        "path",
        "source",
        # The three slots below are declared but never assigned by this class
        # (presumably populated by subclasses).
        "has_dynamic_block",
        "dynamic_attributes",
        "foreach_attrs"
    )

    # Class-level cache shared by every Block instance: dotted attribute key -> parsed JSONPath
    # expression. Entries are never evicted, so it grows with the number of distinct keys seen.
    jsonpath_parsed_statement_cache: "dict[str, JSONPath]" = {}  # noqa: CCE003  # global cache

    def __init__(
            self,
            name: str,
            config: Dict[str, Any],
            path: str,
            block_type: str,
            attributes: Dict[str, Any],
            id: str = "",
            source: str = "",
            has_dynamic_block: bool = False,
            dynamic_attributes: dict[str, Any] | None = None
    ) -> None:
        """
            :param name: unique name given to the block
            :param config: the section in tf_definitions that belong to this block (stored as a deep copy)
            :param path: the file location of the block
            :param block_type: the block's type; combined with ``name`` to form the label (see ``__str__``)
            :param attributes: dictionary of the block's original attributes in the origin file.
                NOTE: the dict is stored by reference and extended in place with the flattened
                inner attributes (dotted keys), so the caller's dict is modified.
            :param id: identifier of the block, stored as-is
            :param source: source of the block, stored as-is
            :param has_dynamic_block: when true, keys present in ``dynamic_attributes`` are not
                expanded into inner attributes
            :param dynamic_attributes: mapping whose keys are excluded from inner-attribute
                expansion; only consulted when ``has_dynamic_block`` is true (``None`` means empty)
        """
        self.name = name
        self.config = pickle_deepcopy(config)
        self.path = path
        self.block_type = block_type
        self.attributes = attributes
        self.id = id
        self.source = source
        # attribute key -> breadcrumbs list recorded by update_attribute
        self.changed_attributes: Dict[str, List[Any]] = {}
        self.breadcrumbs: Dict[str, List[Dict[str, Any]]] = {}

        attributes_to_add = self._extract_inner_attributes(has_dynamic_block, dynamic_attributes)
        self.attributes.update(attributes_to_add)

    def _extract_inner_attributes(
        self, has_dynamic_block: bool = False, dynamic_attributes: dict[str, Any] | None = None
    ) -> Dict[str, Any]:
        """
        Flatten every nested (dict / list-of-dicts) attribute into dotted-key entries.

        :param has_dynamic_block: when true, attributes whose key is in ``dynamic_attributes`` are skipped
        :param dynamic_attributes: keys to skip; ``None`` is treated as an empty mapping
        :return: the dotted-key attributes to merge into ``self.attributes``
        """
        # Previously ``None.keys()`` raised AttributeError when has_dynamic_block was true but no
        # mapping was supplied; an absent mapping now simply skips nothing.
        skipped_keys: dict[str, Any] = (dynamic_attributes or {}) if has_dynamic_block else {}

        attributes_to_add: Dict[str, Any] = {}
        for attribute_key, attribute_value in self.attributes.items():
            if attribute_key in skipped_keys:
                continue
            if self.should_run_get_inner_attributes(attribute_value):
                inner_attributes = self.get_inner_attributes(
                    attribute_key=attribute_key,
                    attribute_value=attribute_value,
                )
                attributes_to_add.update(inner_attributes)
        return attributes_to_add

    def should_run_get_inner_attributes(self, attribute_value: Any) -> bool:
        """Return True for a dict, or a non-empty list whose first element is a dict."""
        return isinstance(attribute_value, dict) or (isinstance(attribute_value, list) and len(attribute_value) > 0
                                                     and isinstance(attribute_value[0], dict))

    def __str__(self) -> str:
        return f"{self.block_type}: {self.name}"

    def get_attribute_dict(self, add_hash: bool = True) -> Dict[str, Any]:
        """
        Build the attribute map exposed for this block.

        Combines the base metadata (``get_base_attributes``) with the block's own attributes
        (``get_origin_attributes``). When present, it also adds the breadcrumbs sorted by key
        under ``CustomAttributes.RENDERING_BREADCRUMBS``. The sorted keys of
        ``changed_attributes`` take part in the hash but are removed from the returned dict.

        :param add_hash: whether to store ``calculate_hash(...)`` of the map under ``CustomAttributes.HASH``
        :return: the combined attribute dict (values are returned as-is, not stringified)
        """
        base_attributes = self.get_base_attributes()
        self.get_origin_attributes(base_attributes)

        if self.changed_attributes:
            # add changed attributes only for calculating the hash
            base_attributes["changed_attributes"] = sorted(self.changed_attributes.keys())

        if self.breadcrumbs:
            sorted_breadcrumbs = dict(sorted(self.breadcrumbs.items()))
            base_attributes[CustomAttributes.RENDERING_BREADCRUMBS] = sorted_breadcrumbs

        if add_hash:
            base_attributes[CustomAttributes.HASH] = calculate_hash(base_attributes)

        if "changed_attributes" in base_attributes:
            # removed changed attributes if it was added previously for calculating hash.
            del base_attributes["changed_attributes"]

        return base_attributes

    def get_origin_attributes(self, base_attributes: Dict[str, Any]) -> None:
        """
        Copy the block's attributes into ``base_attributes`` (mutated in place).

        Single-element lists under non-dotted keys are unwrapped to their only element. The key
        ``"self"`` is stored as ``"self_"``. Dict and list values also contribute their flattened
        dotted-key inner attributes.
        """
        for attribute_key in list(self.attributes.keys()):
            attribute_value = self.attributes[attribute_key]
            if isinstance(attribute_value, list) and len(attribute_value) == 1:
                if '.' not in attribute_key:
                    attribute_value = attribute_value[0]
            # needs to be checked before adding anything to 'base_attributes'
            if attribute_key == "self":
                base_attributes["self_"] = attribute_value
                continue
            if isinstance(attribute_value, (list, dict)):
                inner_attributes = self.get_inner_attributes(attribute_key, attribute_value, False)
                base_attributes.update(inner_attributes)

            base_attributes[attribute_key] = attribute_value

    def get_hash(self) -> str:
        """Return the hash computed by ``get_attribute_dict`` (empty string if absent)."""
        attributes_dict = self.get_attribute_dict()
        return cast("str", attributes_dict.get(CustomAttributes.HASH, ""))

    def update_attribute(
        self,
        attribute_key: str,
        attribute_value: Any,
        change_origin_id: int | None,
        previous_breadcrumbs: list[BreadcrumbMetadata],
        attribute_at_dest: str | None,
        transform_step: bool = False,
    ) -> None:
        """
        Set ``attribute_key`` to ``attribute_value`` and keep the derived views of it consistent.

        Updates, in order:
        - the nested structure in ``self.attributes`` (``update_inner_attribute``);
        - the per-index keys for list values (``update_list_attribute``);
        - every dotted prefix of ``attribute_key``.

        ``previous_breadcrumbs`` is recorded in ``changed_attributes`` for each key touched.

        :param attribute_key: dotted attribute key, e.g. ``"a.b.0"``
        :param attribute_value: the new value
        :param change_origin_id: vertex id the change comes from; when not None and different from
            the last breadcrumb's ``vertex_id``, a BreadcrumbMetadata is appended
        :param previous_breadcrumbs: breadcrumb list; appended to in place and stored by reference
            in ``changed_attributes``
        :param attribute_at_dest: attribute name recorded in the appended breadcrumb
        :param transform_step: when true, stop propagating to shorter prefixes once the trailing
            key segment is ``"1"`` or ``"2"``
        """
        self.update_inner_attribute(
            attribute_key=attribute_key,
            nested_attributes=self.attributes,
            value_to_update=attribute_value
        )

        if (
            self._should_add_previous_breadcrumbs(change_origin_id, previous_breadcrumbs, attribute_at_dest)
            and change_origin_id is not None
        ):
            previous_breadcrumbs.append(BreadcrumbMetadata(change_origin_id, attribute_at_dest))

        # update the numbered attributes, if the new value is a list
        if attribute_value and isinstance(attribute_value, list):
            self.update_list_attribute(attribute_key=attribute_key, attribute_value=attribute_value)

        attribute_key_parts = attribute_key.split(".")
        if len(attribute_key_parts) == 1:
            self.attributes[attribute_key] = attribute_value
            if self._should_set_changed_attributes(change_origin_id, attribute_at_dest):
                self.changed_attributes[attribute_key] = previous_breadcrumbs
            return
        # Walk from the full key towards its shortest prefix (join_trimmed_strings presumably drops
        # the last ``i`` segments), re-wrapping the value one level deeper at each step.
        for i in range(len(attribute_key_parts)):
            key = join_trimmed_strings(char_to_join=".", str_lst=attribute_key_parts, num_to_trim=i)
            if key.find(".") > -1:
                additional_changed_attributes = self.extract_additional_changed_attributes(key)
                if key in self.attributes and isinstance(self.attributes[key], dict) and key != attribute_key:
                    try:
                        self._update_attribute_based_on_jsonpath_key(attribute_value, key)
                    except Exception as e:
                        # Deliberate best-effort: any jsonpath failure falls back to a plain assignment.
                        logging.debug(
                            "Failed updating attribute for key: %s and value %s. "
                            "Falling back to explicitly setting it. Exception - %s",
                            key,
                            attribute_value,
                            e,
                        )
                        self.attributes[key] = attribute_value
                else:
                    self.attributes[key] = attribute_value
                end_key_part = attribute_key_parts[len(attribute_key_parts) - 1 - i]
                if transform_step and end_key_part in ("1", "2"):
                    # if condition logic during the transform step breaks the values
                    return
                attribute_value = {end_key_part: attribute_value}
                if self._should_set_changed_attributes(change_origin_id, attribute_at_dest):
                    self.changed_attributes[key] = previous_breadcrumbs
                    if additional_changed_attributes:
                        for changed_attribute in additional_changed_attributes:
                            self.changed_attributes[changed_attribute] = previous_breadcrumbs

    def _update_attribute_based_on_jsonpath_key(self, attribute_value: Any, key: str) -> None:
        """
        When updating all the attributes we might try to update a specific attribute inside a complex object,
        so we use jsonpath to refer to the specific location only.
        Parsed expressions are memoized in the class-level ``jsonpath_parsed_statement_cache``.
        """
        if key not in Block.jsonpath_parsed_statement_cache:
            jsonpath_key = self._get_jsonpath_key(key)
            expr = parse(jsonpath_key)
            Block.jsonpath_parsed_statement_cache[key] = expr
        else:
            expr = Block.jsonpath_parsed_statement_cache[key]
        match = expr.find(self.attributes)
        if match:
            # NOTE(review): this rebinds ``value`` on the match object returned by ``find``. In
            # upstream jsonpath_ng that does not write back into ``self.attributes`` (``expr.update``
            # does) — confirm the bc_jsonpath_ng fork's semantics.
            match[0].value = attribute_value
        return None

    @staticmethod
    def _get_jsonpath_key(key: str) -> str:
        """
        Convert a dotted attribute key to a JSONPath expression.

        Numeric segments become ``[n]`` indices; segments containing ``/`` or ``::`` are quoted.
        Example: ``"a.0.b/c"`` -> ``'$.a[0]."b/c"'``.
        """
        jsonpath_key = "$."
        key_parts = key.split(".")
        updated_parts = []
        for part in key_parts:
            if part.isnumeric():
                updated_parts.append(f"[{part}]")
            elif "/" in part or "::" in part:
                updated_parts.append(f'"{part}"')
            else:
                updated_parts.append(part)
        jsonpath_key += ".".join(updated_parts)
        # Replace .0 with [0] to match jsonpath style
        jsonpath_key = jsonpath_key.replace(".[", "[")
        return jsonpath_key

    def update_inner_attribute(
        self, attribute_key: str, nested_attributes: list[Any] | dict[str, Any], value_to_update: Any
    ) -> None:
        """
        Recursively write ``value_to_update`` at the dotted ``attribute_key`` inside ``nested_attributes``.

        Numeric segments index lists. For dicts, the longest matching prefix is searched segment by
        segment, because dict keys may themselves contain dots.
        """
        split_key = attribute_key.split(".")
        i = 1
        curr_key = ".".join(split_key[0:i])
        if isinstance(nested_attributes, list):
            if curr_key.isnumeric():
                curr_key_int = int(curr_key)
                if curr_key_int < len(nested_attributes):
                    if not isinstance(nested_attributes[curr_key_int], dict):
                        nested_attributes[curr_key_int] = value_to_update
                    else:
                        self.update_inner_attribute(
                            ".".join(split_key[i:]), nested_attributes[curr_key_int], value_to_update
                        )
            else:
                # NOTE(review): only the first key segment is forwarded to each element, so the
                # remaining segments of a dotted key are dropped here — confirm this is intended.
                for inner in nested_attributes:
                    self.update_inner_attribute(curr_key, inner, value_to_update)
        elif isinstance(nested_attributes, dict):
            # Grow curr_key one segment at a time until it names an existing key; if none matches
            # it ends up as the full key.
            while curr_key not in nested_attributes and i <= len(split_key):
                i += 1
                curr_key = ".".join(split_key[0:i])
            if attribute_key in nested_attributes.keys():
                nested_attributes[attribute_key] = value_to_update
            if len(split_key) == 1 and len(curr_key) > 0:
                nested_attributes[curr_key] = value_to_update
            elif curr_key in nested_attributes.keys():
                self.update_inner_attribute(".".join(split_key[i:]), nested_attributes[curr_key], value_to_update)

    def update_list_attribute(self, attribute_key: str, attribute_value: Any) -> None:
        """Updates list attributes with their index (``"<key>.<idx>"`` -> element)."""

        for idx, value in enumerate(attribute_value):
            self.attributes[f"{attribute_key}.{idx}"] = value

    @staticmethod
    def _should_add_previous_breadcrumbs(
        change_origin_id: int | None, previous_breadcrumbs: list[BreadcrumbMetadata], attribute_at_dest: str | None
    ) -> bool:
        """Return True when there are no breadcrumbs yet or the last one came from another vertex.

        ``attribute_at_dest`` is unused here (kept for overriding implementations).
        """
        return not previous_breadcrumbs or previous_breadcrumbs[-1].vertex_id != change_origin_id

    def extract_additional_changed_attributes(self, attribute_key: str) -> List[str]:
        """
        override in case of a special case where additional attributes are needed to be tracked included in self.changed_attributes
        and self.breadcrumbs, such as terraform dynamic blocks
        :param attribute_key: JSONPath notation of an attribute key that is used for extraction
        :return: list of the additional attributes, in JSONPath notation
        """
        return []

    @staticmethod
    def _should_set_changed_attributes(change_origin_id: int | None, attribute_at_dest: str | None) -> bool:
        """Hook deciding whether update_attribute records changed attributes; always True here."""
        return True

    def get_export_data(self) -> Dict[str, Any]:
        """Return the block's type, name and path."""
        return {"type": self.block_type, "name": self.name, "path": self.path}

    def get_base_attributes(self) -> Dict[str, Any]:
        """Return the block's metadata keyed by the corresponding ``CustomAttributes`` names."""
        return {
            CustomAttributes.BLOCK_NAME: self.name,
            CustomAttributes.BLOCK_TYPE: self.block_type,
            CustomAttributes.FILE_PATH: self.path,
            CustomAttributes.CONFIG: self.config,
            CustomAttributes.LABEL: str(self),
            CustomAttributes.ID: self.id,
            CustomAttributes.SOURCE: self.source,
        }

    @classmethod
    def get_inner_attributes(
        cls,
        attribute_key: str,
        attribute_value: str | List[str] | dict[str, Any],
        strip_list: bool = True  # used by subclass
    ) -> dict[str, Any]:
        """
        Flatten ``attribute_value`` into dotted keys rooted at ``attribute_key``.

        Example: ``("k", [{"a": 1}, 2])`` ->
        ``{"k": [{"a": 1}, 2], "k.0": {"a": 1}, "k.0.a": 1, "k.1": 2}``.
        NOTE: empty-string keys are deleted from the input dict itself (the caller's data is mutated).

        :param strip_list: unused in this implementation; used by subclass
        """
        inner_attributes: dict[str, Any] = {}

        if isinstance(attribute_value, (dict, list)):
            inner_attributes[attribute_key] = [None] * len(attribute_value) if isinstance(attribute_value, list) else {}
            iterator: Collection[int] | Collection[str] = range(len(attribute_value)) if isinstance(
                attribute_value, list
            ) else list(
                attribute_value.keys()
            )
            for key in iterator:
                if key != "":
                    inner_key = f"{attribute_key}.{key}"
                    inner_value = attribute_value[key]  # type:ignore[index]
                    inner_attributes.update(cls.get_inner_attributes(inner_key, inner_value))
                    inner_attributes[attribute_key][key] = inner_attributes[inner_key]
                else:
                    del attribute_value[key]  # type:ignore[arg-type]
        else:
            inner_attributes[attribute_key] = attribute_value
        return inner_attributes
